from networks.fcnet import FCNet
from networks.cscdnet import CSCDNet
from networks.fcnet import FCFMNet


def net_factory(net_type="fcnet", in_chns=3, class_num=3):
    """Build the network selected by ``net_type`` and call ``.cuda()`` on it.

    Args:
        net_type: Which network to build: ``"fcnet"``, ``"cscdnet"`` or
            ``"fcfmnet"``.
        in_chns: Input-channel count passed to ``FCNet`` / ``FCFMNet``.
        class_num: Output-class count passed to ``FCNet`` / ``FCFMNet``.

    Returns:
        The constructed network, after ``.cuda()`` has been called on it
        (presumably a ``torch.nn.Module`` moved to the current CUDA device --
        confirm against the ``networks`` package).

    Raises:
        ValueError: If ``net_type`` is not one of the supported names.
            Previously an unknown name surfaced as an ``UnboundLocalError``
            on ``return net``.
    """
    if net_type == "fcnet":
        net = FCNet(in_chns, class_num)
    elif net_type == "cscdnet":
        # NOTE(review): in_chns / class_num are ignored for this network; the
        # channel counts are hard-coded (inc=6, outc=1). Kept as-is for
        # backward compatibility -- confirm whether this is intentional.
        net = CSCDNet(inc=6, outc=1)
    elif net_type == "fcfmnet":
        net = FCFMNet(in_chns, class_num)
    else:
        raise ValueError(
            f"Unknown net_type {net_type!r}; "
            "expected one of 'fcnet', 'cscdnet', 'fcfmnet'"
        )
    return net.cuda()